{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "089ba8a0-b443-4330-81ff-94b6d5cebc3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\_distributor_init.py:30: UserWarning: loaded more than 1 DLL from .libs:\n",
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\.libs\\libopenblas.FB5AE2TYXYH2IJRDKGDGQ3XBKLKTF43H.gfortran-win_amd64.dll\n",
      "F:\\programData\\conda\\env\\ChatGLM2-6B\\lib\\site-packages\\numpy\\.libs\\libopenblas64__v0.3.23-246-g3d31191b-gcc_10_3_0.dll\n",
      "  warnings.warn(\"loaded more than 1 DLL from .libs:\"\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1512, -0.0015, -0.0737, -0.0992,  0.1091, -0.0365, -0.1498,  0.1453,\n",
       "         -0.2898, -0.0201],\n",
       "        [-0.1787, -0.0086, -0.2511, -0.0150,  0.0475,  0.0211, -0.1390,  0.1056,\n",
       "         -0.2131, -0.0812]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "net = nn.Sequential(nn.Linear(20,256),\n",
    "                    nn.ReLU(),\n",
    "                    nn.Linear(256,10))\n",
    "x = torch.rand(2,20)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e5b6738c-fe71-4747-ad0f-34c17f81b9f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#5.1.1. 自定义块\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(20,256)\n",
    "        self.out = nn.Linear(256,10)\n",
    "\n",
    "    def forward(self,x):\n",
    "        return self.out(F.relu(self.hidden(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e122dac2-8d31-4445-b067-fd52f3f4c99b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0008, -0.0502, -0.0672,  0.0350, -0.1263, -0.1050,  0.2047, -0.1076,\n",
       "          0.1010,  0.1496],\n",
       "        [ 0.0129, -0.1245,  0.0905,  0.0268, -0.0819, -0.1487,  0.0670, -0.1754,\n",
       "          0.0682,  0.1758]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MLP()\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0ed9e9da-e02e-48c9-81bc-019add8ab299",
   "metadata": {},
   "outputs": [],
   "source": [
    "#5.1.2. 顺序块\n",
    "class MySequential(nn.Module):\n",
    "    def __init__(self,*args):\n",
    "        super().__init__()\n",
    "        for idx,module in enumerate(args):\n",
    "            self._modules[str(idx)] = module\n",
    "\n",
    "    def forward(self,x):\n",
    "        for block in self._modules.values():\n",
    "            x = block(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b2ca851f-895b-4340-b49e-5540b96ae3c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0608,  0.0552, -0.0004,  0.2022,  0.0738, -0.1087,  0.2003,  0.1119,\n",
       "         -0.1256, -0.1130],\n",
       "        [ 0.0350, -0.0791, -0.0290,  0.2728,  0.1093, -0.0092,  0.2633,  0.1198,\n",
       "         -0.1491,  0.0022]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MySequential(nn.Linear(20,256),nn.ReLU(),nn.Linear(256,10))\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cd85c1c1-0378-4109-bb12-9ae8a84f81ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5.1.3. 在前向传播函数中执行代码\n",
    "class FixedHiddenMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.rand_weight = torch.rand((20,20),requires_grad =False)\n",
    "        self.linear = nn.Linear(20,20)\n",
    "\n",
    "    def forward(self,x):\n",
    "        x = self.linear(x)\n",
    "        x = F.relu(torch.mm(x,self.rand_weight) +1)\n",
    "        x = self.linear(x)\n",
    "        while x.abs().sum() > 1:\n",
    "            x /= 2\n",
    "        return x.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a9ba9070-2593-4905-a53f-3a501313eaa0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0374, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = FixedHiddenMLP()\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "39037160-f935-4260-8665-582b4d412421",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-0.2122, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(20,64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64,32),\n",
    "            nn.ReLU())\n",
    "        self.linear = nn.Linear(32,16)\n",
    "    def forward(self,x):\n",
    "        return self.linear(self.net(x))\n",
    "\n",
    "chimera = nn.Sequential(NestMLP(),nn.Linear(16,20),FixedHiddenMLP())\n",
    "chimera(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8545f742-0a73-4c61-aa3f-78cac31ef182",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
